In [1]:
# create entry points to spark
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

# Stop a SparkContext left over from an earlier run of this cell so that a new
# one can be created below. Only the expected failures are ignored: `sc` is not
# defined on a fresh kernel (NameError), or it is bound to an object without a
# stop() method (AttributeError). A bare `except:` would also hide
# KeyboardInterrupt and genuine errors.
try:
    sc.stop()
except (NameError, AttributeError):
    pass

sc = SparkContext()
spark = SparkSession(sparkContext=sc)

udf() function and sql types

The pyspark.sql.functions.udf() function is a very important function. It allows us to convert a user-defined function into a pyspark.sql.functions function that can act on the columns of a DataFrame. It makes data transformation much more flexible.

Using udf() could be tricky. The key is to understand how to define the returnType parameter.


In [2]:
from pyspark.sql.types import *
from pyspark.sql.functions import udf

In [3]:
# Read the CSV using its header row and let Spark infer the column types, then
# give the first column (loaded under the name '_c0') the readable name 'model'.
mtcars = (
    spark.read
    .csv('../../data/mtcars.csv', header=True, inferSchema=True)
    .withColumnRenamed('_c0', 'model')
)
mtcars.show(5)


+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+
|            model| mpg|cyl| disp| hp|drat|   wt| qsec| vs| am|gear|carb|
+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+
|        Mazda RX4|21.0|  6|160.0|110| 3.9| 2.62|16.46|  0|  1|   4|   4|
|    Mazda RX4 Wag|21.0|  6|160.0|110| 3.9|2.875|17.02|  0|  1|   4|   4|
|       Datsun 710|22.8|  4|108.0| 93|3.85| 2.32|18.61|  1|  1|   4|   1|
|   Hornet 4 Drive|21.4|  6|258.0|110|3.08|3.215|19.44|  1|  0|   3|   1|
|Hornet Sportabout|18.7|  8|360.0|175|3.15| 3.44|17.02|  0|  0|   3|   2|
+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+
only showing top 5 rows

The structure of the schema passed to returnType has to match the data structure of the return value from the user defined function.

Case 1: divide disp by hp and put the result in a new column

The user defined function returns a float value.


In [4]:
def disp_by_hp(disp, hp):
    """Return the ratio ``disp / hp`` for a single row.

    Written to be wrapped with ``udf(..., returnType=FloatType())``, so it is
    called once per row with plain Python values.

    Args:
        disp: numerator value, or None when the cell is null.
        hp: denominator value, or None when the cell is null.

    Returns:
        ``disp / hp``, or None when either input is None or ``hp`` is zero.
        Returning None yields a null cell instead of raising inside the UDF.
    """
    # A Python UDF receives null cells as None; dividing would raise TypeError.
    # A zero denominator would raise ZeroDivisionError. Map both to null.
    if disp is None or hp is None or hp == 0:
        return None
    return disp / hp

In [5]:
# Wrap the plain Python function as a Spark SQL UDF. The declared return type
# (FloatType) has to match the Python float that disp_by_hp returns.
disp_by_hp_udf = udf(disp_by_hp, FloatType())

In [8]:
# Collect a Column object for every existing column of mtcars.
# Bracket lookup replaces eval('mtcars.' + name): it executes no dynamically
# built code, and it still resolves the column when its name is not a valid
# attribute name (e.g. contains a space) or clashes with a DataFrame method.
all_original_cols = [mtcars[name] for name in mtcars.columns]
all_original_cols


Out[8]:
[Column<b'model'>,
 Column<b'mpg'>,
 Column<b'cyl'>,
 Column<b'disp'>,
 Column<b'hp'>,
 Column<b'drat'>,
 Column<b'wt'>,
 Column<b'qsec'>,
 Column<b'vs'>,
 Column<b'am'>,
 Column<b'gear'>,
 Column<b'carb'>]

In [9]:
# Build the column expression that applies the UDF to the disp and hp columns.
disp_by_hp_col = disp_by_hp_udf(mtcars['disp'], mtcars['hp'])
disp_by_hp_col


Out[9]:
Column<b'disp_by_hp(disp, hp)'>

In [11]:
# Original columns followed by the new UDF-derived column.
all_new_cols = [*all_original_cols, disp_by_hp_col]
all_new_cols


Out[11]:
[Column<b'model'>,
 Column<b'mpg'>,
 Column<b'cyl'>,
 Column<b'disp'>,
 Column<b'hp'>,
 Column<b'drat'>,
 Column<b'wt'>,
 Column<b'qsec'>,
 Column<b'vs'>,
 Column<b'am'>,
 Column<b'gear'>,
 Column<b'carb'>,
 Column<b'disp_by_hp(disp, hp)'>]

In [12]:
# Select every original column plus the computed ratio and display the result.
mtcars.select(*all_new_cols).show()


+-------------------+----+---+-----+---+----+-----+-----+---+---+----+----+--------------------+
|              model| mpg|cyl| disp| hp|drat|   wt| qsec| vs| am|gear|carb|disp_by_hp(disp, hp)|
+-------------------+----+---+-----+---+----+-----+-----+---+---+----+----+--------------------+
|          Mazda RX4|21.0|  6|160.0|110| 3.9| 2.62|16.46|  0|  1|   4|   4|           1.4545455|
|      Mazda RX4 Wag|21.0|  6|160.0|110| 3.9|2.875|17.02|  0|  1|   4|   4|           1.4545455|
|         Datsun 710|22.8|  4|108.0| 93|3.85| 2.32|18.61|  1|  1|   4|   1|           1.1612903|
|     Hornet 4 Drive|21.4|  6|258.0|110|3.08|3.215|19.44|  1|  0|   3|   1|           2.3454545|
|  Hornet Sportabout|18.7|  8|360.0|175|3.15| 3.44|17.02|  0|  0|   3|   2|            2.057143|
|            Valiant|18.1|  6|225.0|105|2.76| 3.46|20.22|  1|  0|   3|   1|            2.142857|
|         Duster 360|14.3|  8|360.0|245|3.21| 3.57|15.84|  0|  0|   3|   4|           1.4693878|
|          Merc 240D|24.4|  4|146.7| 62|3.69| 3.19| 20.0|  1|  0|   4|   2|            2.366129|
|           Merc 230|22.8|  4|140.8| 95|3.92| 3.15| 22.9|  1|  0|   4|   2|           1.4821053|
|           Merc 280|19.2|  6|167.6|123|3.92| 3.44| 18.3|  1|  0|   4|   4|           1.3626016|
|          Merc 280C|17.8|  6|167.6|123|3.92| 3.44| 18.9|  1|  0|   4|   4|           1.3626016|
|         Merc 450SE|16.4|  8|275.8|180|3.07| 4.07| 17.4|  0|  0|   3|   3|           1.5322223|
|         Merc 450SL|17.3|  8|275.8|180|3.07| 3.73| 17.6|  0|  0|   3|   3|           1.5322223|
|        Merc 450SLC|15.2|  8|275.8|180|3.07| 3.78| 18.0|  0|  0|   3|   3|           1.5322223|
| Cadillac Fleetwood|10.4|  8|472.0|205|2.93| 5.25|17.98|  0|  0|   3|   4|            2.302439|
|Lincoln Continental|10.4|  8|460.0|215| 3.0|5.424|17.82|  0|  0|   3|   4|            2.139535|
|  Chrysler Imperial|14.7|  8|440.0|230|3.23|5.345|17.42|  0|  0|   3|   4|           1.9130435|
|           Fiat 128|32.4|  4| 78.7| 66|4.08|  2.2|19.47|  1|  1|   4|   1|           1.1924243|
|        Honda Civic|30.4|  4| 75.7| 52|4.93|1.615|18.52|  1|  1|   4|   2|           1.4557692|
|     Toyota Corolla|33.9|  4| 71.1| 65|4.22|1.835| 19.9|  1|  1|   4|   1|           1.0938462|
+-------------------+----+---+-----+---+----+-----+-----+---+---+----+----+--------------------+
only showing top 20 rows

Case 2: create an array column that contains disp and hp values


In [34]:
# define function
def merge_two_columns(col1, col2):
    """Combine two values from one row into a two-element list of floats.

    Written to be wrapped with ``udf(..., returnType=ArrayType(FloatType()))``.

    Args:
        col1: first value; anything ``float()`` accepts, or None for a null cell.
        col2: second value; same rules as ``col1``.

    Returns:
        ``[float(col1), float(col2)]``, where a None input is kept as None in
        its position instead of raising TypeError from ``float(None)``.
    """
    return [None if value is None else float(value) for value in (col1, col2)]

# Turn the Python function into a Spark SQL UDF; its declared result type is an
# array whose elements are all FloatType, matching the list of floats returned.
array_merge_two_columns_udf = udf(merge_two_columns, ArrayType(FloatType()))

In [35]:
# Column expression that packs disp and hp into one array per row.
array_col = array_merge_two_columns_udf(mtcars['disp'], mtcars['hp'])
array_col


Out[35]:
Column<b'merge_two_columns(disp, hp)'>

In [36]:
# Original columns followed by the new array column.
all_new_cols = [*all_original_cols, array_col]
all_new_cols


Out[36]:
[Column<b'model'>,
 Column<b'mpg'>,
 Column<b'cyl'>,
 Column<b'disp'>,
 Column<b'hp'>,
 Column<b'drat'>,
 Column<b'wt'>,
 Column<b'qsec'>,
 Column<b'vs'>,
 Column<b'am'>,
 Column<b'gear'>,
 Column<b'carb'>,
 Column<b'merge_two_columns(disp, hp)'>]

In [37]:
# Show the first five rows without truncating the array column's contents.
mtcars.select(*all_new_cols).show(n=5, truncate=False)


+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+---------------------------+
|model            |mpg |cyl|disp |hp |drat|wt   |qsec |vs |am |gear|carb|merge_two_columns(disp, hp)|
+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+---------------------------+
|Mazda RX4        |21.0|6  |160.0|110|3.9 |2.62 |16.46|0  |1  |4   |4   |[160.0, 110.0]             |
|Mazda RX4 Wag    |21.0|6  |160.0|110|3.9 |2.875|17.02|0  |1  |4   |4   |[160.0, 110.0]             |
|Datsun 710       |22.8|4  |108.0|93 |3.85|2.32 |18.61|1  |1  |4   |1   |[108.0, 93.0]              |
|Hornet 4 Drive   |21.4|6  |258.0|110|3.08|3.215|19.44|1  |0  |3   |1   |[258.0, 110.0]             |
|Hornet Sportabout|18.7|8  |360.0|175|3.15|3.44 |17.02|0  |0  |3   |2   |[360.0, 175.0]             |
+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+---------------------------+
only showing top 5 rows

ArrayType vs. StructType

Both ArrayType and StructType can be used to build returnType for a list. The difference is:

  1. ArrayType requires all elements in the list to have the same elementType, while StructType can have different elementTypes.
  2. StructType represents a Row object.

Define an ArrayType with elementType being FloatType.


In [38]:
# define function
def merge_two_columns(col1, col2):
    """Combine two values from one row into a two-element list of floats.

    The result is meant for a UDF whose returnType is ``ArrayType(FloatType())``.

    Args:
        col1: first value; anything ``float()`` accepts, or None for a null cell.
        col2: second value; same rules as ``col1``.

    Returns:
        ``[float(col1), float(col2)]``, where a None input is kept as None in
        its position instead of raising TypeError from ``float(None)``.
    """
    return [None if value is None else float(value) for value in (col1, col2)]
# Every element of the returned list shares one type, so ArrayType fits.
array_type = ArrayType(elementType=FloatType())
array_merge_two_columns_udf = udf(merge_two_columns, array_type)

Define a StructType with one elementType being StringType and the other being FloatType.


In [50]:
# define function
def merge_two_columns(col1, col2):
    """Combine two values from one row into a ``[str, float]`` pair.

    The result is meant for a UDF whose returnType is a two-field StructType
    (a string field followed by a float field).

    Args:
        col1: value for the string field, or None for a null cell.
        col2: value for the float field; anything ``float()`` accepts, or None.

    Returns:
        ``[str(col1), float(col2)]``, where a None input stays None in its
        position. Without this, ``str(None)`` would turn a null into the
        literal text 'None' and ``float(None)`` would raise TypeError.
    """
    text_value = None if col1 is None else str(col1)
    float_value = None if col2 is None else float(col2)
    return [text_value, float_value]
# The two returned values have different types, so the return schema is a
# struct: field 'f1' holds the string and field 'f2' holds the float.
struct_type = (
    StructType()
    .add('f1', StringType())
    .add('f2', FloatType())
)
struct_merge_two_columns_udf = udf(merge_two_columns, struct_type)

array column expression: both values are float type values


In [52]:
# Array column expression built from hp and disp (in that order).
array_col = array_merge_two_columns_udf(mtcars['hp'], mtcars['disp'])
array_col


Out[52]:
Column<b'merge_two_columns(hp, disp)'>

struct column expression: first value is a string and the second value is a float type value.


In [53]:
# Struct column expression pairing the model name with its disp value.
struct_col = struct_merge_two_columns_udf(mtcars['model'], mtcars['disp'])
struct_col


Out[53]:
Column<b'merge_two_columns(model, disp)'>

Results


In [56]:
# Display the array column and the struct column side by side, untruncated.
result_cols = [array_col, struct_col]
mtcars.select(result_cols).show(truncate=False)


+---------------------------+------------------------------+
|merge_two_columns(hp, disp)|merge_two_columns(model, disp)|
+---------------------------+------------------------------+
|[110.0, 160.0]             |[Mazda RX4, 160.0]            |
|[110.0, 160.0]             |[Mazda RX4 Wag, 160.0]        |
|[93.0, 108.0]              |[Datsun 710, 108.0]           |
|[110.0, 258.0]             |[Hornet 4 Drive, 258.0]       |
|[175.0, 360.0]             |[Hornet Sportabout, 360.0]    |
|[105.0, 225.0]             |[Valiant, 225.0]              |
|[245.0, 360.0]             |[Duster 360, 360.0]           |
|[62.0, 146.7]              |[Merc 240D, 146.7]            |
|[95.0, 140.8]              |[Merc 230, 140.8]             |
|[123.0, 167.6]             |[Merc 280, 167.6]             |
|[123.0, 167.6]             |[Merc 280C, 167.6]            |
|[180.0, 275.8]             |[Merc 450SE, 275.8]           |
|[180.0, 275.8]             |[Merc 450SL, 275.8]           |
|[180.0, 275.8]             |[Merc 450SLC, 275.8]          |
|[205.0, 472.0]             |[Cadillac Fleetwood, 472.0]   |
|[215.0, 460.0]             |[Lincoln Continental, 460.0]  |
|[230.0, 440.0]             |[Chrysler Imperial, 440.0]    |
|[66.0, 78.7]               |[Fiat 128, 78.7]              |
|[52.0, 75.7]               |[Honda Civic, 75.7]           |
|[65.0, 71.1]               |[Toyota Corolla, 71.1]        |
+---------------------------+------------------------------+
only showing top 20 rows